/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the
 * "License"). Please refer to the License for details. You may not use this
 * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
 * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
 * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
 * for the full text of the License.
 */

#ifndef EXAMPLES_NORMALIZATION_NORMALIZE_CUSTOM_H
#define EXAMPLES_NORMALIZATION_NORMALIZE_CUSTOM_H
#include "kernel_operator.h"

namespace NormalizeCustomKernel {
/**
 * Tiling parameters consumed by KernelNormalize::Init.
 * The input/output tensors are laid out as [A, R], with each row stored using a
 * stride of rLengthWithPadding elements.
 */
struct NormalizeTiling {
    uint32_t aLength;            // number of rows (A axis); also the length of mean/var/rstd
    uint32_t rLength;            // row length R actually normalized
    uint32_t rLengthWithPadding; // row stride in elements, R including padding
    uint32_t tmpLocalSize;       // size in bytes of the scratch buffer handed to AscendC::Normalize
};

// Not referenced within this header. NOTE(review): presumably the 32-byte
// alignment unit used by DataCopy / local buffers — TODO confirm or remove.
constexpr uint8_t LOCAL_BYTES = 32;

/**
 * Kernel wrapper around AscendC::Normalize.
 *
 * The whole problem is processed in a single CopyIn -> Compute -> CopyOut pass.
 * Every input is staged into its own single-slot VECIN queue. Normalize writes the
 * normalized data and the rstd tensor into VECOUT queues, and both are copied back
 * to global memory.
 *
 * @tparam T             element type of input x and of the normalized output.
 * @tparam U             element type of gamma and beta.
 * @tparam isReuseSource forwarded unchanged to AscendC::Normalize.
 */
template <typename T, typename U, bool isReuseSource = false>
class KernelNormalize {
public:
    __aicore__ inline KernelNormalize() {}

    /**
     * Bind the global buffers and size the local queues from the tiling data.
     *
     * @param inputX_gm     input x, [A, R] with rows stored at stride rLengthWithPadding (type T).
     * @param inputMean_gm  mean, [A] (float).
     * @param inputVar_gm   variance, [A] (float).
     * @param gamma_gm      gamma, [R] padded to rLengthWithPadding (type U).
     * @param beta_gm       beta, [R] padded to rLengthWithPadding (type U).
     * @param output_gm     normalized output, same layout as inputX_gm (type T).
     * @param outputRstd_gm rstd output, [A] (float).
     * @param tilingData    sizes and scratch-buffer byte count for this launch.
     */
    __aicore__ inline void Init(GM_ADDR inputX_gm, GM_ADDR inputMean_gm, GM_ADDR inputVar_gm, GM_ADDR gamma_gm,
        GM_ADDR beta_gm, GM_ADDR output_gm, GM_ADDR outputRstd_gm, NormalizeTiling tilingData) {
        aLength = tilingData.aLength;
        rLength = tilingData.rLength;
        rLengthWithPadding = tilingData.rLengthWithPadding;
        tmpLocalBytes = tilingData.tmpLocalSize;
        const uint32_t totalLength = aLength * rLengthWithPadding;

        inputX_global.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(inputX_gm), totalLength);            // [A, R]
        inputMean_global.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(inputMean_gm), aLength);      // [A]
        inputVar_global.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(inputVar_gm), aLength);        // [A]
        // gamma/beta are held in GlobalTensor<U>, so their GM pointers must be U*, not float*.
        inputGamma_global.SetGlobalBuffer(reinterpret_cast<__gm__ U *>(gamma_gm), rLengthWithPadding);  // [R]
        inputBeta_global.SetGlobalBuffer(reinterpret_cast<__gm__ U *>(beta_gm), rLengthWithPadding);    // [R]

        output_global.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(output_gm), totalLength);             // [A, R]
        outputRstd_global.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(outputRstd_gm), aLength);     // [A]

        // One buffer per queue: each tensor is staged whole, in a single pass.
        pipe.InitBuffer(inQueueX, 1, sizeof(T) * totalLength);
        pipe.InitBuffer(inQueueMean, 1, sizeof(float) * aLength);
        pipe.InitBuffer(inQueueVar, 1, sizeof(float) * aLength);
        // Sized by U so the local buffer matches the element type actually copied in.
        pipe.InitBuffer(inQueueGamma, 1, sizeof(U) * rLengthWithPadding);
        pipe.InitBuffer(inQueueBeta, 1, sizeof(U) * rLengthWithPadding);

        pipe.InitBuffer(outQueue, 1, sizeof(T) * totalLength);
        pipe.InitBuffer(outQueueRstd, 1, sizeof(float) * aLength);
    }

    /** Run the full pipeline once: load inputs, normalize, store results. */
    __aicore__ inline void Process() {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    /** Copy x, mean, var, gamma and beta from global memory into their VECIN queues. */
    __aicore__ inline void CopyIn() {
        AscendC::LocalTensor<T> inputXLocal = inQueueX.AllocTensor<T>();
        // mean/var are float tensors regardless of T; they must match GlobalTensor<float> for DataCopy.
        AscendC::LocalTensor<float> inMeanLocal = inQueueMean.AllocTensor<float>();
        AscendC::LocalTensor<float> inVarLocal = inQueueVar.AllocTensor<float>();
        AscendC::LocalTensor<U> inGammaLocal = inQueueGamma.AllocTensor<U>();
        AscendC::LocalTensor<U> inBetaLocal = inQueueBeta.AllocTensor<U>();

        // NOTE(review): DataCopy counts are used as-is; presumably the host pads
        // aLength / rLengthWithPadding to the required alignment — TODO confirm.
        AscendC::DataCopy(inputXLocal, inputX_global, aLength * rLengthWithPadding);
        AscendC::DataCopy(inMeanLocal, inputMean_global, aLength);
        AscendC::DataCopy(inVarLocal, inputVar_global, aLength);
        AscendC::DataCopy(inGammaLocal, inputGamma_global, rLengthWithPadding);
        AscendC::DataCopy(inBetaLocal, inputBeta_global, rLengthWithPadding);

        inQueueX.EnQue(inputXLocal);
        inQueueMean.EnQue(inMeanLocal);
        inQueueVar.EnQue(inVarLocal);
        inQueueGamma.EnQue(inGammaLocal);
        inQueueBeta.EnQue(inBetaLocal);
    }

    /** Dequeue the staged inputs, call AscendC::Normalize, and enqueue output and rstd. */
    __aicore__ inline void Compute() {
        AscendC::LocalTensor<T> inputXLocal = inQueueX.DeQue<T>();
        AscendC::LocalTensor<float> inputMeanLocal = inQueueMean.DeQue<float>();
        AscendC::LocalTensor<float> inputVarLocal = inQueueVar.DeQue<float>();
        AscendC::LocalTensor<U> inputGammaLocal = inQueueGamma.DeQue<U>();
        AscendC::LocalTensor<U> inputBetaLocal = inQueueBeta.DeQue<U>();

        AscendC::LocalTensor<T> outLocal = outQueue.AllocTensor<T>();
        AscendC::LocalTensor<float> outRstdLocal = outQueueRstd.AllocTensor<float>();

        constexpr float epsilon = 0.001f;
        AscendC::NormalizePara para = {aLength, rLength, rLengthWithPadding};
        // The config is a reference template argument of Normalize, so it needs static storage.
        static constexpr AscendC::NormalizeConfig config = AscendC::GetNormalizeConfig(false, false);

        // PopStackBuffer reports whether a scratch buffer was obtained. Using the tensor after
        // a failed pop would hand Normalize an invalid buffer, so only use it on success.
        AscendC::LocalTensor<uint8_t> sharedTmpBuffer;
        if (AscendC::PopStackBuffer<uint8_t, AscendC::TPosition::LCM>(sharedTmpBuffer)) {
            sharedTmpBuffer.SetSize(tmpLocalBytes);
            AscendC::Normalize<U, T, isReuseSource, config>(outLocal, outRstdLocal, inputMeanLocal, inputVarLocal,
                inputXLocal, inputGammaLocal, inputBetaLocal, sharedTmpBuffer, epsilon, para);
        } else {
            // No stack buffer available: let the API allocate its own temporary space.
            AscendC::Normalize<U, T, isReuseSource, config>(outLocal, outRstdLocal, inputMeanLocal, inputVarLocal,
                inputXLocal, inputGammaLocal, inputBetaLocal, epsilon, para);
        }

        outQueue.EnQue(outLocal);
        outQueueRstd.EnQue(outRstdLocal);
        inQueueX.FreeTensor(inputXLocal);
        inQueueMean.FreeTensor(inputMeanLocal);
        inQueueVar.FreeTensor(inputVarLocal);
        inQueueGamma.FreeTensor(inputGammaLocal);
        inQueueBeta.FreeTensor(inputBetaLocal);
    }

    /** Copy the normalized output (type T) and rstd (float) back to global memory. */
    __aicore__ inline void CopyOut() {
        // Tensor element types must match the queues' contents and the destination GlobalTensors.
        AscendC::LocalTensor<T> outLocal = outQueue.DeQue<T>();
        AscendC::LocalTensor<float> outRstdLocal = outQueueRstd.DeQue<float>();

        AscendC::DataCopy(output_global, outLocal, aLength * rLengthWithPadding);
        AscendC::DataCopy(outputRstd_global, outRstdLocal, aLength);

        outQueue.FreeTensor(outLocal);
        outQueueRstd.FreeTensor(outRstdLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueMean;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueVar;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueGamma;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueueBeta;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueueRstd;

    AscendC::GlobalTensor<T> inputX_global;
    AscendC::GlobalTensor<float> inputMean_global;
    AscendC::GlobalTensor<float> inputVar_global;
    AscendC::GlobalTensor<U> inputGamma_global;
    AscendC::GlobalTensor<U> inputBeta_global;
    AscendC::GlobalTensor<T> output_global;
    AscendC::GlobalTensor<float> outputRstd_global;

    uint32_t tmpLocalBytes = 0;  // scratch size in bytes for Normalize (from tiling)
    uint32_t aLength;            // rows (A axis)
    uint32_t rLength;            // normalized row length R
    uint32_t rLengthWithPadding; // row stride in elements
};

} // namespace NormalizeCustomKernel

#endif // EXAMPLES_NORMALIZATION_NORMALIZE_CUSTOM_H
